#This script uses GPT-4o to classify articles as critical or not
#It produces two csv files with the classified articles for Egypt and Tunisia
#NOTE: This script will not produce exactly the same results as the paper,
#due to the stochastic nature of the language model.
import pandas as pd
from openai import OpenAI
import json

# Module-level OpenAI client; classify_content sends every request through it.
# NOTE(review): the key below is an empty placeholder -- take care not to
# commit a real key to version control.
client = OpenAI(api_key='') #Replace with your own API key

def _is_missing_content(content):
    """Return True when ``content`` is None, NaN/NA, or a blank string."""
    if content is None:
        return True
    if isinstance(content, str):
        return not content.strip()
    try:
        return bool(pd.isna(content))
    except (TypeError, ValueError):
        # Non-scalar values (lists, arrays, ...) have no single NA answer;
        # treat them as present, as the original code did.
        return False


def classify_content(content, output_variables, function_description, function_name):
    """Classify one piece of text with GPT-4o through an OpenAI function call.

    The user message is ``function_description`` followed by the quoted
    ``content``. The model is offered a single function whose parameter
    schema is ``output_variables``; the JSON arguments of the function call
    it returns are the classification.

    Args:
        content: Text to classify. Missing values (``None``, NaN, ``pd.NA``)
            and blank strings are not sent to the API.
        output_variables: JSON-schema dict describing the function
            parameters; its ``'properties'`` keys name the output fields.
        function_description: Task instructions, used both as the function
            description and as the start of the user message.
        function_name: Name of the function offered to the model.

    Returns:
        dict: The parsed function-call arguments on success. On any failure
        (missing content, API error, no function call in the response, empty
        or malformed arguments) a dict mapping every key of
        ``output_variables['properties']`` to ``pd.NA``.
    """
    error_response = {var: pd.NA for var in output_variables['properties']}

    # Empty CSV cells come back from pandas as NaN; interpolating them into
    # the prompt would send the literal text 'nan' to the model and yield a
    # meaningless classification, so skip the request entirely.
    if _is_missing_content(content):
        print("Content is missing or empty; skipping API call.")
        return error_response

    function = {
        "name": function_name,
        "description": function_description,
        "parameters": output_variables,
    }
    message_content = f"{function_description}: \n\n'{content}'"

    try:
        # No specific function call is forced here, so the model may answer
        # without calling the function; that case is handled just below.
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": message_content}],
            functions=[function],
            max_tokens=1000
        )
        message = response.choices[0].message if response.choices else None
        function_call = message.function_call if message else None
        if not function_call:
            print("Expected function call data not found in response.")
            return error_response

        function_call_arguments = function_call.arguments
        if not function_call_arguments:
            print("Function call arguments are empty.")
            return error_response

        # The arguments arrive as a JSON-encoded string.
        result = json.loads(function_call_arguments)
    except json.JSONDecodeError:
        print("Failed to decode JSON from function call arguments.")
        return error_response
    except Exception as e:
        # Deliberate best-effort: one failed request must not abort the run.
        print(f"An error occurred: {e}")
        return error_response

    # Callers read the result with ``.get``; anything but a JSON object
    # (e.g. a bare list or string) cannot be used as a classification.
    if not isinstance(result, dict):
        print("Function call arguments are not a JSON object.")
        return error_response

    return result

def process_dataset(dfsamp, classify_function, output_variables, function_description, function_name, text_column):
    """Classify every row of ``dfsamp`` and store the outputs in new columns.

    For each key in ``output_variables['properties']`` a column named
    ``<key>_gpt`` is created if it does not already exist, then filled row by
    row with the value ``classify_function`` returns for that key.

    Args:
        dfsamp: DataFrame to classify; it is modified in place. ``None`` is
            accepted and reported.
        classify_function: Callable invoked as ``classify_function(content,
            output_variables, function_description, function_name)`` and
            expected to return a dict keyed like
            ``output_variables['properties']``.
        output_variables: JSON-schema dict whose ``'properties'`` keys name
            the output fields.
        function_description: Passed through to ``classify_function``.
        function_name: Passed through to ``classify_function``.
        text_column: Name of the column holding the text to classify.

    Returns:
        The same ``dfsamp`` object with the ``*_gpt`` columns filled in, or
        ``None`` when ``dfsamp`` is ``None``.
    """
    if dfsamp is None:
        print("No data to process.")
        return None

    output_keys = list(output_variables['properties'])

    # Ensure new columns for output exist in DataFrame
    for key in output_keys:
        if key + '_gpt' not in dfsamp.columns:
            dfsamp[key + '_gpt'] = pd.NA

    total_rows = len(dfsamp)
    # Count rows positionally: the index label need not be an integer, nor
    # start at zero, so it cannot be used for the progress message.
    for position, (index, row) in enumerate(dfsamp.iterrows(), start=1):
        print(f"Processing {position}/{total_rows}...")
        content = row[text_column]
        result = classify_function(content, output_variables, function_description, function_name)

        # A classifier that returns nothing usable leaves the row as NA
        # instead of aborting the whole run.
        if not isinstance(result, dict):
            result = {}

        for key in output_keys:
            dfsamp.at[index, key + '_gpt'] = result.get(key, pd.NA)

    print("\nProcessing complete.")
    return dfsamp


# Function-calling configuration shared by both sample runs below: the
# function name, the task instructions, and the schema of the values the
# model is asked to return.
function_name_arabress= "classify_article" # Note function name cannot contain spaces
# Task instructions; classify_content uses this text both as the function
# description and as the opening of the user message.
function_description_arabress = "This task involves reading some sentences from Arabic-language news articles. In each piece of text, there will be some discussion of an individual called 'TARGETWORD'. Your task is to code on a scale of 1-9 how critical that article is of this individual or group of individuals.."
# JSON-schema style description of the function parameters. Each key under
# "properties" becomes a "<key>_gpt" column in process_dataset's output.
output_variables_arabress = {
    "type": "object",
    "properties": {
        # The article text itself is declared as a required output field.
        "article": {
            "type": "string",
            "description": "The content of the news article to be classified."
        },
        # Categorical judgement.
        "scorefac": {
            "type": "string",
            "enum": ["critical", "uncritical", "neutral", "not-applicable"],
            "description": "Read the news article and decide whether it is critical toward 'TARGETWORD' or not. You can choose one of four categories, where 'critical' means the article is critical; 'uncritical' means it is not critical, 'neutral' means it is neither critical nor uncritical and 'not-applicable' means the article is not about 'TARGETWORD' and is instead talking about something else.."
        },
        # Numeric judgement on a 1-9 scale.
        # NOTE(review): the description tells the model to code not-applicable
        # articles as 12, but the enum only allows 1-9 -- confirm which is
        # intended before changing either.
        "scoreint": {
            "type": "integer",
            "enum": [
                1, 2, 3, 4, 5, 6, 7, 8, 9
            ],
            "description": "Read the news article and decide whether it is critical toward 'TARGETWORD' or not on a scale of 1 to 9 where 1 indicates not critical at all and 9 indicates very critical, estimate how critical this article is of 'TARGETWORD'. If the article is not-applicable, code as 12."
        }
    },
    "required": ["article", "scorefac", "scoreint"]
}

# Run the classifier over each sample in turn (masress first, then turess):
# read the sample CSV, classify every row, and write the result next to it.
text_column = 'content'  # The name of the column in the DataFrame containing the content to classify

sample_files = (
    # masress
    (
        "data/output/cos_sims_dsl/masressdata_sample.csv",
        "data/output/cos_sims_dsl/masressdata_sample_gpt.csv",
    ),
    # turess
    (
        "data/output/cos_sims_dsl/turessdata_sample.csv",
        "data/output/cos_sims_dsl/turessdata_sample_gpt.csv",
    ),
)

for filename, output_filename in sample_files:
    # Start from a clean 0..n-1 index regardless of what the CSV carried.
    dfsamp = pd.read_csv(filename).reset_index(drop=True)

    processed_dfsamp = process_dataset(
        dfsamp,
        classify_content,
        output_variables_arabress,
        function_description_arabress,
        function_name_arabress,
        text_column,
    )

    processed_dfsamp.to_csv(output_filename)